# Introduction ----- 

# Code for making Figure 4 of the paper: analysis of the spindle - slow oscillation (SO) coupling data

#Load up ----

pacman::p_load(tidyverse, patchwork, ghibli, eegUtils, 
               colorspace, circular, CircStats, tidybayes, bayestestR, modelr)


#Get our default settings
source("./eLife Submission Scripts/Analysis-Common-Utilities.R")


# Figure 4 ======

## Fig 4A: Example Data =======

# Load the example data used for the illustrative traces in panel A.
# The directory is spelled "eLife Submission Data" to match every other data
# path in this script; the previous lowercase "elife" only resolved on
# case-insensitive file systems and failed on e.g. Linux.
d_example = read_rds("./eLife Submission Data/sleep_study_eeg_coupling_example.rds")


# Get the spindle and SO edges.
# "in" = samples where the edge column is > 0, "out" = samples where it is < 1.
# NOTE(review): this assumes the *Edges columns are 0/1 indicators - confirm
# against the data export.
spin_in  = which(d_example$spindleEdges > 0)
spin_out = which(d_example$spindleEdges < 1)

# First in-spindle sample and its time on the x axis
spin_onset   = spin_in[1]
spin_onset_t = d_example$xAx[spin_onset]

# The spindle runs for the whole length of the chunk, so there is no offset here
# spin_offset  = 

# SO edges
so_in  = which(d_example$soEdges > 0)
so_out = which(d_example$soEdges < 1)

so_onset    = so_in[1]
so_onset_t  = d_example$xAx[so_onset]

# SO offset = first "out" sample after the onset (NA if there is none in the chunk)
so_offset   = so_out[so_out > so_onset][1]
so_offset_t = d_example$xAx[so_offset]

# Panel A: stacked example traces with the spindle / SO edges marked.
# Raw EEG and the SO-filtered trace are overlaid in the first panel
# (measure2 == 1); every other measure gets its own panel keyed by its factor
# level index (sigma -> 2, spinCoef -> 3, soPhase -> 5). The original
# ifelse() on the factor produced these same integer codes implicitly.
p_example = 
  d_example |>
  dplyr::select(xAx, raw, sigma, spinCoef, so, soPhase) |>
  pivot_longer(cols = -xAx, names_to = "measure", values_to = "value") |>
  mutate(
    measure  = factor(measure, levels = c("raw","sigma","spinCoef","so","soPhase")),
    measure2 = ifelse(measure %in% c("raw","so"), 1, as.integer(measure))
  ) |>
  ggplot(aes(x = xAx, y = value, group = measure)) +
  geom_line(colour = "black") +
  facet_wrap(~measure2, ncol = 1, scales = "free_y") +
  # dashed = spindle onset; dotted = SO onset and offset
  geom_vline(xintercept = spin_onset_t, lty = 2, colour = "black") +
  geom_vline(xintercept = c(so_onset_t, so_offset_t), lty = 3, colour = "grey20") +
  theme_bw() +
  theme(strip.background     = element_blank(),
        panel.grid           = element_blank(),
        strip.text           = element_text(colour = "grey20", size = 8),
        axis.text.y          = element_text(colour = "grey20", size = 8),
        axis.text.x          = element_text(colour = "grey20", size = 8),
        axis.title           = element_text(colour = "black", size = 8),
        legend.title         = element_blank(),
        legend.justification = "top",
        legend.position      = "none") +
  labs(x = "Time (Seconds)")

# The vertical lines are turned into coloured bars in Inkscape for the final figure

## Fig 4B: Coupling Angle Histograms -----

# Per-subject summary data; reused by panel B and by both supplements below
d_sum = readr::read_rds("./eLife Submission Data/sleep_study_eeg_summary_data.rds")


### Stats on the angles ======

# Reviewer requests: per-group circular mean / SD of the Cz coupling angles
# and a Rayleigh test of non-uniformity within each group.
d_r_test = 
  d_sum |>
  filter(measure %in% "itpc_overlap_angle") |>
  unnest(data) |>
  filter(electrode %in% "Cz") |>
  group_by(group) |>
  nest() |>
  mutate(
    circ_data = map(data, \(grp) circular(pull(grp, value), units = "degrees")),
    c_mean    = map_dbl(circ_data, mean.circular),
    c_sd      = map_dbl(circ_data, sd.circular),
    r_test    = map(circ_data, rayleigh.test),
    r_table   = map(r_test, \(rt) tibble(statistic = rt$statistic, p.value = rt$p.value))
  )

# Summary table: one row per group
d_r_test |>
  unnest(r_table) |>
  dplyr::select(group, c_mean, c_sd, statistic, p.value)

# Between-group comparison (Watson-Williams test): no difference
watson.williams.test(d_r_test$circ_data)

### Make the plot =====

# Group phase histograms of the Cz coupling angles.
# Negative angles are shifted by +360 so that the x axis runs from 0 to 360.
# The per-group circular mean from d_r_test gets the SAME shift, so its dashed
# line lands on that axis. mean.circular() can return a negative angle, which
# was previously drawn off the left edge of the histogram.
p_phase_histogram =
  d_sum |>
  filter(measure %in% c("itpc_overlap_angle")) |>
  unnest(data) |>
  ungroup() |>
  filter(electrode %in% c('Cz')) |>
  mutate(group = factor(group,levels = c("Sib","22q"))) |>
  mutate(value = ifelse(value < 0, 360 + value, value)) |>
  ggplot(aes(value,colour = group, fill = group)) +
  geom_histogram(breaks = seq(0,360, by = 45))+
  # dotted reference lines at the quarter-cycle phases
  geom_vline(xintercept = c(0, 90, 180, 270), colour = "grey20", lty = 3,size = 0.2) +
  # dashed line at each group's circular mean, mapped onto the same 0-360 axis
  geom_vline(data = d_r_test |> mutate(c_mean = ifelse(c_mean < 0, 360 + c_mean, c_mean)),
             aes(xintercept = c_mean), colour = "black", lty = 2, size = 0.2) +
  facet_wrap(~group , ncol = 1) +
  scale_x_continuous(breaks = c(0,90,180,270,360))+
  scale_y_continuous(breaks = c(0,4,8,12,16)) +
  scale_fill_manual(values = cols) +
  scale_colour_manual(values = cols) +
  theme_bw() +
  theme(panel.grid = element_blank(),
        strip.background = element_blank(),
        strip.text = element_text(colour = "grey20",size = 8),
        axis.text.y = element_text(colour = "grey20",size = 8),
        axis.text.x = element_text(colour = "grey20",size = 8),
        axis.title = element_text(colour = "black",size = 8),
        legend.title = element_blank(),
        legend.justification = "top",
        legend.position = "none" ) +
  labs(x = "SO Phase (Degrees)", y = "Count")



## Fig 4C: Topoplots ----- 

# GAMM posterior data for the group-difference topoplots
d_gamm = read_rds("./eLife Submission Data/sleep_study_topoplot_posterior_data.rds")


# Facet strip labels, keyed by measure name (the names also fix the panel order)
coupling_details <- 
  c(spinOverlap_z      = "Overlap (Z score)",
    itpc_overlap_mag_z = "MRL - Overlap (Z score)")


# Group-difference topoplot: fill = diff (squished into [-1, 1.5]),
# opacity = pv, keeping only grid points inside the head circle (circ_scale)
topo_coupling_diff =
  d_gamm |>
  ungroup() |>
  dplyr::select(measure, post_draws) |>
  filter(measure %in% names(coupling_details)) |>
  mutate(measure = factor(measure, levels = names(coupling_details))) |>
  unnest(post_draws) |>
  mutate(incircle = sqrt(x^2 + y^2) < circ_scale) |>
  filter(incircle) |>
  ggplot(aes(x = x, y = y, fill = diff)) +
  geom_raster(aes(alpha = pv)) +
  geom_mask(r = circ_scale, size = 0.5) +
  geom_head(r = circ_scale, size = 0.5) +
  scale_fill_distiller(palette = "RdBu",
                       limits  = c(-1, 1.5),
                       oob     = scales::squish) +
  scale_alpha(range = c(0.1, 1)) +
  facet_wrap(~measure,
             ncol           = 2,
             strip.position = "top",
             labeller       = labeller(measure = coupling_details)) +
  coord_equal() +
  theme_void() +
  theme(legend.position = "bottom",
        strip.text      = element_text(colour = "grey20", size = 8, angle = 0),
        legend.text     = element_text(size = 6)) +
  guides(alpha = "none")


## Fig 4D: Angular data -----


# Load the coupling-angle GAMM data. The directory is spelled "eLife" to match
# the other data paths in this script; the previous lowercase "elife" only
# resolved on case-insensitive file systems.
d_vm_gamm = read_rds("./eLife Submission Data/sleep_study_topoplot_angle_data.rds")

# Group-difference topoplot of coupling angle.
# diff is multiplied by 180/pi, i.e. converted from radians to degrees
# (NOTE(review): assumes the stored diff is in radians, as this conversion
# implies). The colour scale is squished into [-60, 60] degrees, opacity = pv,
# and only grid points inside the head circle (circ_scale) are kept.
p_angle_diff_vm = 
  d_vm_gamm |>
  dplyr::select(x,y,diff,pv) |>
  mutate(diff = (diff*180)/pi)  |>
  mutate(incircle = sqrt(x ^ 2 + y ^ 2) < circ_scale) |>
  filter(incircle) |>
  ggplot(aes(x = x, y = y, fill = diff)) +
  geom_raster(aes(alpha = pv)) +
  geom_mask(r = circ_scale, size = 0.5) +
  geom_head(r = circ_scale,size = 0.5) +   
  scale_fill_distiller(palette = "RdBu", limits = c(-60,60),
                       oob = scales::squish) +
  coord_equal()+
  theme_void() +
  scale_alpha(range = c(0.1, 1)) +
  theme(legend.position = "bottom",
        strip.text = element_text(colour = "grey20",size = 8,angle = -90),
        legend.text = element_text(size = 6)) +
  guides(alpha = "none")


# Assemble plot elements: the two coupling-difference topoplots on top
# (2/3 of the height) above the angle-difference topoplot (1/3)
tp1 = (topo_coupling_diff / p_angle_diff_vm) +
  plot_layout(heights = c(2, 1))


## Save ====

# Final assembly: example traces | phase histograms | topoplot column,
# with panels tagged A, B, C, ...
fig_4 = (p_example | p_phase_histogram | tp1) +
  plot_annotation(tag_levels = "A")

ggsave("./Figures/figure_4.pdf",
       plot   = fig_4,
       width  = 18,
       height = 16,
       units  = "cm")



# Figure 4 Supplement 1: Individual coupling data ------


# Individual data for Figures 2 and 3 are shown as boxplots in supplementary
# figures, so we keep that pattern here.
d_sum_individual = 
  d_sum |>
  filter(measure %in% c("spinOverlap_z","itpc_overlap_mag_z"))


# Set up plot labels, looked up by measure name.
# The previous positional assignment relied on d_sum's rows being in exactly
# this order; any other order would silently swap the axis labels.
# as.character() guards against `measure` being a factor, because indexing a
# named vector by a factor uses the integer codes, not the names.
d_sum_individual = 
  d_sum_individual |>
  mutate(plot_labels = unname(
    c(spinOverlap_z      = "Spindle - SW Overlap (z-scored)",
      itpc_overlap_mag_z = "Spindle - SW MRL (z-scored)")[as.character(measure)]
  ))

## Fig 4 - S1A-B ======

# Boxplots of the Cz values by group, one panel per coupling measure,
# with every subject overlaid as a jittered point
p_individual_1 = 
  d_sum_individual |>
  ungroup() |>
  unnest(data) |>
  filter(electrode == "Cz") |>
  group_by(measure, plot_labels) |>
  nest() |>
  mutate(plot = map2(data, plot_labels, \(df, y_label) {
    ggplot(df, aes(x = group, y = value, fill = group, colour = group)) +
      geom_boxplot(alpha = 0.2, lwd = 0.25,
                   outlier.color = NA, outlier.fill = NA) +
      geom_point(size = 0.5,
                 position = ggforce::position_jitternormal(sd_x = 0.05, sd_y = 0),
                 alpha = 0.6) +
      scale_fill_manual(values = cols) +
      scale_colour_manual(values = cols) +
      theme_bw() +
      theme(strip.background     = element_blank(),
            strip.text.x         = element_blank(),
            panel.grid           = element_blank(),
            strip.text           = element_text(colour = "grey20", size = 8),
            axis.text.y          = element_text(colour = "grey20", size = 8),
            axis.text.x          = element_text(colour = "grey20", size = 8),
            axis.title           = element_text(colour = "black", size = 8),
            legend.title         = element_blank(),
            legend.justification = "top",
            legend.position      = "none") +
      labs(x = "Group", y = y_label)
  })) |>
  pull(plot) |>
  wrap_plots(ncol = 2)


## Fig 4 - S1C-D: Scatter with age =======

# Cz value against age at EEG, coloured by group, with a per-group linear fit
p_individual_2 = 
  d_sum_individual |>
  ungroup() |>
  unnest(data) |>
  filter(electrode == "Cz") |>
  group_by(measure, plot_labels) |>
  nest() |>
  mutate(plot = map2(data, plot_labels, \(df, y_label) {
    ggplot(df, aes(x = age_eeg, y = value, fill = group, colour = group)) +
      geom_point(size = 0.5, alpha = 0.6) +
      geom_smooth(size = 0.5, formula = y ~ x, method = "lm") +
      scale_fill_manual(values = cols) +
      scale_colour_manual(values = cols) +
      theme_bw() +
      theme(strip.background     = element_blank(),
            strip.text.x         = element_blank(),
            panel.grid           = element_blank(),
            strip.text           = element_text(colour = "grey20", size = 8),
            axis.text.y          = element_text(colour = "grey20", size = 8),
            axis.text.x          = element_text(colour = "grey20", size = 8),
            axis.title           = element_text(colour = "black", size = 8),
            legend.title         = element_blank(),
            legend.justification = "top",
            legend.position      = "none") +
      labs(x = "Age", y = y_label)
  })) |>
  pull(plot) |>
  wrap_plots(ncol = 2)

# Assemble: boxplots (top row) over age scatter plots (bottom row), tagged A-D
p_individual = (p_individual_1 / p_individual_2) +
  plot_annotation(tag_levels = "A")

## Save =====

ggsave("./Figures/figure_4_supplement_1.pdf",
       plot   = p_individual,
       width  = 10,
       height = 8,
       units  = "cm")


#Figure 4 Supplement 2: Group topoplots ======

## Coupling Data =====

# One topoplot per coupling measure (all electrodes), faceted by group.
# Each plot is stored in the list-column `plots`, in nest() row order.
topo_event_plots = 
  d_sum_individual |> 
  unnest(data) |>
  ungroup() |>
  mutate(group = factor(group,levels = c("Sib","22q"))) |>
  group_by(measure,plot_labels) |>
  nest() |>
  mutate(plots = map2(data, plot_labels, \(df, plot_label) {
    ggplot(df, aes(x = x,
                   y = y,
                   z = value,
                   fill = value,
                   label = electrode)) +
      geom_topo(grid_res = 200,
                colour = "white",
                size = 0.1,
                interp_limit = "head",
                chan_markers = "point",
                chan_size = 0.25,
                head_size = 0.5,
                method = "gam", breaks = 10) + 
      scale_fill_viridis_c(option = "H")+
      facet_wrap(~group, ncol = 1)+
      theme_void() + 
      coord_equal() + 
      labs(subtitle = plot_label, fill = plot_label) +
      theme(plot.subtitle = element_text(colour = "black",size = 8) ,
            legend.position = "bottom",
            legend.title = element_blank(),
            legend.text = element_text(colour = "black",size = 6,angle = 40),
            strip.text  = element_text(colour = "black",size = 8,angle = 0))
  }))


# Two-column grid of the coupling topoplots. `ncol` replaces cowplot's
# deprecated `cols` argument.
# NOTE(review): tp_s1 is not saved or used elsewhere in this script.
tp_s1 = cowplot::plot_grid(plotlist = topo_event_plots$plots, ncol = 2)



## Angular data ====

# Per-subject, per-electrode coupling angles
d_angle = 
  d_sum |>
  filter(measure == "itpc_overlap_angle") |>
  unnest(data) |>
  ungroup()

# Convert age to a z score and the angles from degrees to radians
d_angle2 = 
  d_angle |> 
  rename(itpc_overlap_angle = value) |>
  mutate(group = factor(group, c("Sib","22q"))) |>
  drop_na(itpc_overlap_angle) |>
  mutate(age_eeg  = zscore(age_eeg),
         itpc_overlap_angle = itpc_overlap_angle * pi/180)

# Topoplot of each group's per-electrode circular mean angle, converted back to
# degrees. kappa (e_k), rho (mrl) and the dispersion table (c_d) are carried
# along in the plot data but are not mapped to any aesthetic.
p_angle = 
  d_angle2 |>
  group_by(group,electrode) |>
  summarise(e_k = est.kappa(itpc_overlap_angle),
            mrl = est.rho(itpc_overlap_angle),
            m_a = circ.mean(itpc_overlap_angle),
            c_d = circ.disp(itpc_overlap_angle) |> as_tibble()) |>
  left_join(topo, by = "electrode") |>
  ungroup() |>
  # rad2deg is vectorised, so no per-element map is needed
  mutate(m_a = pracma::rad2deg(m_a)) |>
  ggplot(aes(x = x,
             y = y,
             z = m_a,
             fill = m_a,
             label = electrode)) +
  geom_topo(grid_res = 200,
            interp_limit = "head",
            chan_markers = "point",
            chan_size = 0.25,
            head_size = 0.5,
            method = "gam",
            color = "black",
            breaks = 45) +
  # scale_fill_distiller(palette = "RdBu") + 
  facet_wrap(~group, ncol = 1)+
  theme_void() + 
  coord_equal() + 
  labs(fill = "Angle",subtitle = "ITPC Overlap - Angle") +
  scale_fill_gradientn(colours=rainbow_hcl(100,l = 70), limits = c(-180,180), breaks = c(-180,-90,0,90,180))


# Supplement 2 layout: the two coupling topoplots plus the mean-angle topoplot.
# Stored as tp_s2 rather than tp1 so that it no longer overwrites the Figure 4
# topoplot column held in tp1. `ncol` replaces cowplot's deprecated `cols`.
tp_s2 = cowplot::plot_grid(plotlist = list(topo_event_plots$plots[[1]],
                                           topo_event_plots$plots[[2]],
                                           p_angle),
                           ncol = 3)

## Save =====
ggsave("./Figures/figure_4_supplement_2.pdf",plot = tp_s2, width = 18, height = 12, units = "cm")
